import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from skfda.representation.grid import FDataGrid
from skfda.representation.basis import FourierBasis, BSplineBasis
from skfda.ml.regression import LinearRegression


def ipw_estimator(train_sample, est_sample, return_inv_weight=False, eps=1e-2):
    """Inverse probability weighting (IPW) estimate of the treatment effect on silhouette functions.

    A random forest propensity model P(A = 1 | X) is fitted on the training sample and
    evaluated on the estimation sample. Every estimation silhouette is weighted by
    A / pi - (1 - A) / (1 - pi), and the weighted curves are averaged over individuals.

    Args:
        train_sample (list): Sample used for fitting nuisance components. Triplet of (phi_tr, A_tr, X_tr).
                                phi_tr: Collection of silhouette functions (not used by this estimator).
                                A_tr: Treatment. Shape: [n_tr,] (flattened before fitting).
                                X_tr: Covariates of dimension d. Shape: [n_tr, d].
        est_sample (list): Sample used for estimation. Triplet of (phi_est, A_est, X_est).
                                phi_est: Collection of silhouette functions. Shape: [n_est, n_hom_dim, resolution].
                                A_est: Treatment. Shape: [n_est,] (flattened before use).
                                X_est: Covariates of dimension d. Shape: [n_est, d].
        return_inv_weight (bool, optional): Whether to also return the inverse weight values. Defaults to False.
        eps (float, optional): Estimated propensity scores are clipped to [eps, 1 - eps] so that the
                                inverse weights stay bounded. Defaults to 1e-2.

    Returns:
        (list): List containing "n_hom_dim" IPW estimates of shape [resolution, ].
        If return_inv_weight is True, the tuple (ipw, inv_weight) is returned instead, where
        inv_weight has shape [n_est, 1].
    """
    _, A_tr, X_tr = train_sample
    phi_est, A_est, X_est = est_sample

    n_hom_dim = phi_est.shape[-2]   # number of homology dimensions

    # Fit the propensity score on the training sample.
    # NOTE: no random_state is passed, so the fitted forest (and the estimate) varies between calls.
    rf = RandomForestClassifier()
    rf.fit(X_tr, np.ravel(A_tr))
    # Column 1 is the probability of the second class in rf.classes_ (A = 1 for a 0/1 treatment).
    pi_hat = rf.predict_proba(X_est)[:, 1]

    # Keep the propensity score away from 0 and 1. Clipping (rather than replacing only exact
    # 0/1 values) also bounds scores that are merely very close to the boundary, so a score of
    # 0.001 can no longer receive a larger weight than a score of exactly 0.
    pi_hat = np.clip(pi_hat, eps, 1 - eps)

    # Construct the IPW estimator on the estimation sample.
    treated = np.asarray(A_est, dtype=float).ravel()    # tolerate a trailing singleton axis
    inv_weight = (treated / pi_hat - (1 - treated) / (1 - pi_hat))[:, np.newaxis]   # shape: [n_est, 1]

    # inv_weight broadcasts over the grid axis of each [n_est, resolution] slice.
    ipw = [
        np.mean(inv_weight * phi_est[:, hom_dim, :], axis=0)
        for hom_dim in range(n_hom_dim)
    ]

    if return_inv_weight:
        return ipw, inv_weight
    return ipw


def _predict_group_mean(phi_group, X_group, X_new, tseq, basis):
    """Fit a function-on-scalar regression on one treatment arm and evaluate its predictions on tseq.

    Args:
        phi_group (np.ndarray): Silhouette functions of one arm in one homology dimension,
                                one row per individual, discretised on tseq.
        X_group (pd.DataFrame): Covariates of the same individuals.
        X_new (pd.DataFrame): Covariates at which the fitted regression is evaluated.
        tseq (np.ndarray): Grid points.
        basis: Basis in which the silhouette functions are represented.

    Returns:
        (np.ndarray): Predicted curves, one row per row of X_new and one column per grid point.
    """
    curves = FDataGrid(data_matrix=phi_group, grid_points=tseq).to_basis(basis)
    model = LinearRegression(fit_intercept=True)
    model.fit(X_group, curves)
    values = np.asarray(model.predict(X_new)(tseq))
    # Drop only the trailing codomain axis. A bare squeeze() would also drop the sample axis
    # when X_new has a single row (or the grid axis when tseq has a single point).
    return values.reshape(values.shape[:2])


def plugin_estimator(train_sample, est_sample, tseq, n_basis=10, return_mu=False):
    """Plug-in (outcome regression) estimate of the treatment effect on silhouette functions.

    For every homology dimension, a function-on-scalar linear regression of the silhouette on the
    covariates is fitted separately on the control and the treated individuals of the training
    sample. Both fits are evaluated at the covariates of the estimation sample, and the difference
    (treated minus control) is averaged over the estimation sample.

    Args:
        train_sample (list): Sample used for fitting nuisance components. Triplet of (phi_tr, A_tr, X_tr).
                                phi_tr: Collection of silhouette functions. Shape: [n_tr, n_hom_dim, resolution].
                                A_tr: Treatment. Shape: [n_tr,] (flattened before use; nonzero means treated).
                                X_tr: Covariates of dimension d. Shape: [n_tr, d].
        est_sample (list): Sample used for estimation. Triplet of (phi_est, A_est, X_est).
                                Only X_est (covariates, shape [n_est, d]) is used by this estimator.
        tseq (np.ndarray): Grid points. Its first and last entries define the domain of the basis.
        n_basis (int): Number of basis functions requested from the Fourier basis.
        return_mu (bool, optional): Whether to also return the regression function values. Defaults to False.

    Returns:
        (list): List containing "n_hom_dim" plug-in estimates of shape [resolution, ].
        If return_mu is True, the tuple (plugin, mu0_list, mu1_list) is returned instead, where
        mu0_list / mu1_list hold, per homology dimension, the predicted control / treated curves
        of shape [n_est, resolution].
    """
    phi_tr, A_tr, X_tr = train_sample
    _, _, X_est = est_sample

    n_hom_dim = phi_tr.shape[-2]                                        # number of homology dimensions
    basis = FourierBasis(domain_range=tseq[[0, -1]], n_basis=n_basis)   # basis used for function-on-scalar regression

    # Split the training sample into control and treated group. The mask is flattened so that a
    # column-vector treatment does not break the row selection below.
    ind_tr = np.asarray(A_tr).astype(bool).ravel()

    # The covariate frames do not depend on the homology dimension, so build them once.
    X_tr0 = pd.DataFrame(X_tr[~ind_tr])     # covariates of control group in train sample
    X_tr1 = pd.DataFrame(X_tr[ind_tr])      # covariates of treated group in train sample
    X_new = pd.DataFrame(X_est)             # covariates at which both regressions are evaluated

    # Fit and evaluate the regression function of each arm in each homology dimension.
    mu0_list = []   # regression functions of the control group, each of shape [n_est, resolution]
    mu1_list = []   # regression functions of the treated group, each of shape [n_est, resolution]
    for hom_dim in range(n_hom_dim):
        mu0_list.append(_predict_group_mean(phi_tr[~ind_tr, hom_dim, :], X_tr0, X_new, tseq, basis))
        mu1_list.append(_predict_group_mean(phi_tr[ind_tr, hom_dim, :], X_tr1, X_new, tseq, basis))

    plugin = [np.mean(mu1 - mu0, axis=0) for mu1, mu0 in zip(mu1_list, mu0_list)]

    if return_mu:
        return plugin, mu0_list, mu1_list
    return plugin


def aipw_estimator(train_sample, est_sample, tseq, n_basis=10):
    """Augmented IPW (doubly robust) estimate of the treatment effect on silhouette functions.

    Combines the plug-in estimate with an inverse-probability-weighted residual term:
    the IPW estimate minus the weighted average of the regression function of the arm
    each individual was actually assigned to.

    Args:
        train_sample (list): Sample used for fitting nuisance components. Triplet of (phi_tr, A_tr, X_tr).
                                phi_tr: Collection of silhouette functions. Shape: [n_tr, n_hom_dim, resolution].
                                A_tr: Treatment of the training individuals.
                                X_tr: Covariates of dimension d. Shape: [n_tr, d].
        est_sample (list): Sample used for estimation. Triplet of (phi_est, A_est, X_est).
                                phi_est: Collection of silhouette functions. Shape: [n_est, n_hom_dim, resolution].
                                A_est: Treatment. Shape: [n_est,] (a new axis is appended here for broadcasting).
                                X_est: Covariates of dimension d. Shape: [n_est, d].
        tseq (np.ndarray): Grid points, forwarded to the plug-in estimator.
        n_basis (int): Number of basis functions, forwarded to the plug-in estimator.

    Returns:
        (list): List containing "n_hom_dim" doubly robust estimates of shape [resolution, ].
    """
    silhouettes_tr, _, _ = train_sample
    _, treatment_est, _ = est_sample

    # number of homology dimensions, taken from the training silhouettes
    n_dims = silhouettes_tr.shape[-2]

    # Both building blocks are fitted on the training sample and evaluated on the estimation sample.
    ipw, inv_weight = ipw_estimator(train_sample, est_sample, return_inv_weight=True)
    plugin, mu0_list, mu1_list = plugin_estimator(
        train_sample, est_sample, tseq, n_basis, return_mu=True
    )

    # Column vector so the treatment broadcasts against the per-individual regression curves.
    treat_col = treatment_est[:, np.newaxis]

    dr = []
    for k in range(n_dims):
        # Regression function of the arm each individual actually received.
        observed_arm_mu = treat_col * mu1_list[k] + (1 - treat_col) * mu0_list[k]
        # Weighted residual: IPW term minus the weighted fitted values.
        augmentation = ipw[k] - np.mean(inv_weight * observed_arm_mu, axis=0)
        dr.append(plugin[k] + augmentation)
    return dr